from threading import Lock
from quant.utils import abstract_method, data_routing_key, partial_function
from quant.markets.util import MarketInfoEngine, RoutingEngine, Mapping, Mapping3, pick_handler, Timer, set_value_with_book, ChannelSelecting
from quant.const import *


class Markets:
    """Hub that owns market-data channels and dispatches the raw data they deliver.

    ``add_market`` / ``remove_market`` create and tear down ``Channel`` objects
    obtained from the exchange factory. ``feed_raw`` pushes each raw message to
    the info engine, to the handler selected by ``pick_handler`` and to the
    receive-time ``Timer``. Consumers register through the ``subscribe*`` methods.
    """

    def __init__(self):
        # Function-local import kept from the original; presumably it avoids an
        # import cycle with quant.exchanges. TODO confirm.
        from quant.exchanges import get_factory
        self._get_fac = get_factory

        self.info_engine = MarketInfoEngine()
        self.routing_engine = RoutingEngine()

        self.channels = Mapping()       # (event, exchange, symbol, frequency) -> list of Channel
        self.handlers = Mapping()       # (event, exchange, symbol, frequency) -> Handler
        self.order_books = Mapping3()   # (exchange, symbol) -> OrderBook
        self.mids = Mapping3()          # (exchange, symbol) -> mid price

        self.timer = Timer()            # stepped by feed_raw with each recv_time (local clock)
        self.now = None                 # recv_time of the most recent feed_raw call

        self.channel_selecting = ChannelSelecting(self)

        # Register set_value_with_book (presumably bound to self by
        # partial_function) with the timer at interval 1.
        refresh_from_books = partial_function(set_value_with_book, self)
        self.timer.register(refresh_from_books, 1)
        self.lock = Lock()              # serialises add_market / remove_market

    def add_market(self, event, exchange, symbol, frequency=DataFrequency.Normal):
        """Create, connect and register one more channel for the given market.

        The channel class comes from the exchange's factory. Repeated calls with
        the same arguments add further channels; existing ones are never reused.
        ``connect()`` runs while ``self.lock`` is held, and a channel whose
        ``connect()`` raises is not added to ``self.channels``.
        """
        with self.lock:
            channel_cls = self._get_fac(exchange).channel_cls
            fresh = channel_cls(event, exchange, symbol, frequency, self)
            fresh.connect()
            bucket = self.channels.set_default(event, exchange, symbol, frequency, [])
            bucket.append(fresh)

    def remove_market(self, event, exchange, symbol, frequency=DataFrequency.Normal):
        """Disconnect and unregister every channel stored for the given market."""
        with self.lock:
            bucket = self.channels.get_default(event, exchange, symbol, frequency, [])
            # Iterate over a snapshot, because ``bucket`` shrinks inside the loop.
            for ch in list(bucket):
                ch.disconnect()
                bucket.remove(ch)

    def feed_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        """Dispatch one raw message received at ``recv_time``.

        Order of operations:
        1. Push the message to the info engine.
        2. Record ``recv_time`` as ``self.now``.
        3. Pass it to the handler chosen by ``pick_handler``, keyed by
           ``data_routing_key``.
        4. Step the timer to ``recv_time``.

        This method does not take ``self.lock``.
        """
        self.info_engine.push_raw(event, exchange, symbol, frequency, recv_time, raw)
        self.now = recv_time
        routing_key = data_routing_key(event, exchange, symbol)
        target = pick_handler(self, event, exchange, symbol, frequency)
        target.process(routing_key, recv_time, raw)
        self.timer.step(recv_time)

    def subscribe_raw(self, handler):
        """Register ``handler`` with the info engine's raw-data subscription."""
        self.info_engine.subscribe_raw(handler)

    def subscribe(self, event, exchange, symbol, handler):
        """Route data for (event, exchange, symbol) to ``handler``; frequency is not part of the key."""
        routing_key = data_routing_key(event, exchange, symbol)
        self.routing_engine.subscribe(routing_key, handler)

    def subscribe_all(self, handler):
        """Subscribe ``handler`` under the routing engine's ``ALL_DATA`` key."""
        everything = self.routing_engine.ALL_DATA
        self.routing_engine.subscribe(everything, handler)


class Channel:
    """Base class for one market-data connection for (event, exchange, symbol, frequency).

    Every instance gets a ``channel_id`` from a counter shared by all
    subclasses, and is recorded in ``Channel.all_instance`` under that id.
    Subclasses may override ``init`` for extra setup and are expected to
    implement the connection lifecycle methods below.
    """

    _id = 0
    all_instance = {}   # channel_id -> Channel; this class never removes entries

    def __init__(self, event, exchange, symbol, frequency, markets: Markets):
        # The counter lives on ``Channel`` itself (not ``type(self)``), so all
        # subclasses draw ids from one sequence.
        Channel._id = Channel._id + 1
        new_id = Channel._id
        Channel.all_instance[new_id] = self
        self.channel_id = new_id

        self.event, self.exchange, self.symbol = event, exchange, symbol
        self.frequency, self.markets = frequency, markets

        # Subclass hook; runs after every attribute above is set.
        self.init()

    def init(self):
        """Optional subclass hook called at the end of ``__init__``; does nothing by default."""

    @abstract_method
    def connect(self):
        """Open the channel (``Markets.add_market`` calls this right after construction)."""

    @abstract_method
    def reconnect(self):
        """Re-establish the channel; no caller is visible in this module."""

    @abstract_method
    def disconnect(self):
        """Close the channel (``Markets.remove_market`` calls this before unregistering it)."""

    @abstract_method
    def is_open(self):
        """Presumably report whether the channel is currently connected; subclass-defined."""


class Handler:
    """Base class for processing raw data of one (event, exchange, symbol, frequency).

    Subclasses may override ``init`` for extra setup and are expected to
    implement ``process``. ``Markets.feed_raw`` calls ``process`` on the
    handler returned by ``pick_handler``.
    """

    def __init__(self, event, exchange, symbol, frequency, markets: Markets):
        self.event, self.exchange, self.symbol = event, exchange, symbol
        self.frequency, self.markets = frequency, markets

        # Subclass hook; runs after every attribute above is set.
        self.init()

    def init(self):
        """Optional subclass hook called at the end of ``__init__``; does nothing by default."""

    @abstract_method
    def process(self, routing_key, recv_time, raw):
        """Handle one raw message.

        ``routing_key`` is the ``data_routing_key(event, exchange, symbol)`` value
        and ``recv_time`` is the local receive time passed to ``feed_raw``.
        """


"""
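TODO notes (English translations added):
- 不要阻塞feed_raw()线程: do not block the thread that calls feed_raw().
- sleep()改成60: change the sleep() argument to 60.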
"""


if __name__ == '__main__':
    # Manual, interactive smoke test. Nothing here runs when the module is imported.
    from quant.utils import set_test_mode
    set_test_mode()

    def try_timer():
        """Step a Timer by hand (+0.5 per Enter press) and watch its callbacks print."""
        tm = Timer()

        def on_timer():
            print('on_timer')

        def on_timer2():
            print('on_timer2')

        # ``register(callable, interval)`` is the Timer API used by
        # Markets.__init__, which this script runs, so it is known to exist.
        tm.register(on_timer, 0.1)
        tm.register(on_timer2, 0.5)

        i = 1000
        while True:
            i += 0.5
            tm.step(i)
            input(':')    # pace the loop: one step per Enter press

    def perf_timer():
        """Benchmark Timer.step with one 0.1-interval and ten 1-interval no-op callbacks."""
        from quant.utils import perf_test
        tm = Timer()
        tm.register(lambda: None, 0.1)
        for _ in range(10):
            tm.register(lambda: None, 1)

        def perf():
            # Reads Timer's private _last_step to advance 10 past the previous step.
            tm.step(tm._last_step + 10)
        perf_test(perf)

    mar = Markets()
    # Four identical subscriptions; presumably these are redundant channels for
    # the channel selector to compare by lag. TODO confirm.
    for _ in range(4):
        mar.add_market('Book', 'Binance', 'btc/usdt.swap')

    mar.channel_selecting.start(50)

    # Commands at the ':' prompt:
    #   empty line -> print the lag table and every entry of dict_server_time_lag()
    #   other text -> call select() on every channel selector
    #   Ctrl-C / EOF -> leave the loop
    # The old loop ended with a second input(':') that silently discarded one
    # line of user input per iteration; it has been removed.
    while True:
        try:
            cmd = input(':')
        except (KeyboardInterrupt, EOFError):
            break
        if not cmd:
            print(mar.channel_selecting.tabulate_lag())
            for name, lag in mar.channel_selecting.dict_server_time_lag().items():
                print(name, lag)
        else:
            for selector in mar.channel_selecting._channel_selectors.values():
                selector.select()








